import os
import gzip
import pybedtools
from Bio import SeqIO
from Bio.Seq import reverse_complement


# Genome assembly name; passed to read_chromosome_sizes() and
# read_chromosome_names() to locate "<assembly>/<assembly>.chrom.sizes".
# NOTE(review): read_snrnas() hard-codes the hg38 2bit path and the GRCh38
# Ensembl files, so changing this value alone does not switch assemblies.
assembly = 'hg38'

def read_chromosome_names(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes/"):
    """Map short sequence names to the chromosome names in a chrom.sizes file.

    Reads ``<directory>/<assembly>/<assembly>.chrom.sizes`` (two
    whitespace-separated columns: chromosome name and size) and builds a
    dictionary whose values are the chromosome names found in the file and
    whose keys are derived from them as follows:

    * ``chrM``                      -> ``MT``
    * ``chr1`` ... ``chr22``, ``chrX``, ``chrY`` -> name without ``chr``
    * ``chrUn_<acc>v<n>``           -> ``<acc>.<n>``
    * ``chr<N>_<acc>v<n>_alt`` / ``_random`` -> ``<acc>.<n>``

    Parameters:
        assembly: assembly name, used for both the subdirectory and file name.
        directory: root directory holding one subdirectory per assembly.

    Returns:
        dict mapping the derived key to the chromosome name in the file.

    Raises:
        AssertionError: if a name does not follow one of the patterns above.
        Exception: if a name has more than three ``_``-separated parts.
        ValueError: if a line is not two columns or the size is not an integer.
    """
    # A set gives constant-time membership tests for the sanity checks below.
    chromosomes = {"chr%s" % i for i in list(range(1, 23)) + ['X', 'Y']}
    names = {}
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    print("Reading", path)
    # The context manager closes the file even if a check below fails.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            # The size itself is not needed here, but it must be an integer.
            int(size)
            terms = chromosome.split("_")
            if len(terms) == 1:
                assert chromosome.startswith("chr")
                if chromosome == "chrM":
                    key = "MT"
                else:
                    key = chromosome[3:]
                    assert chromosome in chromosomes
            elif len(terms) == 2:
                assert terms[0] == "chrUn"
                key, version = terms[1].split("v")
                key = "%s.%s" % (key, version)
            elif len(terms) == 3:
                assert terms[2] in ("alt", "random")
                key, version = terms[1].split("v")
                key = "%s.%s" % (key, version)
                assert terms[0] in chromosomes
            else:
                raise Exception("Unknown chromosome %s" % chromosome)
            names[key] = chromosome
    return names

def read_chromosome_sizes(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes/"):
    """Read chromosome sizes from ``<directory>/<assembly>/<assembly>.chrom.sizes``.

    Each line of the file holds a chromosome name and its size, separated by
    whitespace.

    Parameters:
        assembly: assembly name, used for both the subdirectory and file name.
        directory: root directory holding one subdirectory per assembly.

    Returns:
        dict mapping each chromosome name to its size as an int.

    Raises:
        AssertionError: if a chromosome name does not start with ``chr``.
        ValueError: if a line is not two columns or the size is not an integer.
    """
    sizes = {}
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    # The context manager closes the file even if a line fails validation.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            assert chromosome.startswith("chr")
            sizes[chromosome] = int(size)
    return sizes

def read_gene_names():
    """Read gene names for all transcripts in the Ensembl GRCh38.100 GTF file.

    Only lines whose third field is ``transcript`` are used.

    Returns:
        dict mapping ``"<transcript_id>.<transcript_version>"`` to the
        ``gene_name`` attribute of that transcript line.
    """
    directory = "/osc-fs_home/scratch/mdehoon/Data/Ensembl/"
    filename = "Homo_sapiens.GRCh38.100.gtf.gz"
    path = os.path.join(directory, filename)
    print("Reading", path)
    names = {}
    # The context manager closes the gzip stream even if parsing fails.
    with gzip.open(path, "rt") as handle:
        for line in pybedtools.BedTool(handle):
            if line.fields[2] != 'transcript':
                continue
            transcript_id = line.attrs['transcript_id']
            transcript_version = line.attrs['transcript_version']
            transcript = "%s.%s" % (transcript_id, transcript_version)
            names[transcript] = line.attrs['gene_name']
    return names

def read_rnacentral_annotations(path=None):
    """Read human snRNA-related annotations from an RNAcentral GPI file.

    The file must be gzip-compressed, start with the line
    ``!gpi-version: 1.2``, and contain seven tab-separated columns per line.
    Only entries with taxon ``9606`` are considered; of those, an entry is
    kept if its type column (6th) is ``snRNA`` or ``misc_RNA``, or if its
    name column (4th) contains the word ``spliceosomal``.

    Parameters:
        path: path to the ``.gpi.gz`` file; defaults to the RNAcentral file
            in the local data directory.

    Returns:
        dict mapping the accession (the 2nd column without its ``_9606``
        suffix) to the name column.

    Raises:
        AssertionError: if the header or a data line has an unexpected layout.
    """
    if path is None:
        directory = "/osc-fs_home/scratch/mdehoon/Data/RNAcentral"
        filename = "rnacentral.gpi.gz"
        path = os.path.join(directory, filename)
    print("Reading", path)
    annotations = {}
    # The context manager closes the gzip stream even if a check fails.
    with gzip.open(path, "rt") as handle:
        line = next(handle)
        assert line.strip() == '!gpi-version: 1.2'
        for line in handle:
            words = line.strip().split("\t")
            assert len(words) == 7
            assert words[0] == "RNAcentral"
            taxon, taxon_id = words[6].split(":")
            assert taxon == "taxon"
            if taxon_id != "9606":  # keep Homo sapiens only
                continue
            accession, taxon_id = words[1].split("_")
            assert taxon_id == "9606"
            category = words[5]
            annotation = words[3]
            if category in ("snRNA", "misc_RNA") or "spliceosomal" in annotation:
                annotations[accession] = annotation
    return annotations

def read_rfam_annotations(path=None):
    """Read Rfam annotations from a gzip-compressed tab-separated file.

    Every line must have nine tab-separated columns; the first column is
    used as the accession and the ninth as the annotation. If an accession
    occurs more than once, the last occurrence wins.

    Parameters:
        path: path to the ``.tsv.gz`` file; defaults to the Rfam annotation
            file in the local RNAcentral data directory.

    Returns:
        dict mapping accession to annotation.

    Raises:
        AssertionError: if a line does not have exactly nine columns.
    """
    if path is None:
        directory = "/osc-fs_home/scratch/mdehoon/Data/RNAcentral"
        filename = "rfam_annotations.tsv.gz"
        path = os.path.join(directory, filename)
    print("Reading", path)
    annotations = {}
    # The context manager closes the gzip stream even if a check fails.
    with gzip.open(path, "rt") as handle:
        for line in handle:
            words = line.strip().split("\t")
            assert len(words) == 9
            annotations[words[0]] = words[8]
    return annotations

def _describe_snrna(feature, name, gene_names, gene_descriptions,
                    rnacentral_annotations, rfam_annotations):
    """Build the ``gene|RNAcentral|Rfam`` description string for one snRNA feature."""
    gene_description = gene_names.get(name)
    if gene_description is None:
        # Fall back to the note of a previously seen 'gene' feature, if any.
        genes = feature.qualifiers['gene']
        assert len(genes) == 1
        gene_description = gene_descriptions.get(genes[0], "")
    rnacentral_description = ""
    rfam_description = ""
    for db_xref in feature.qualifiers['db_xref']:
        source, accession = db_xref.split(":")
        if source == "RNAcentral":
            # Only the first RNAcentral cross-reference is used.
            rnacentral_description = rnacentral_annotations.get(accession, "")
            rfam_description = rfam_annotations.get(accession, "")
            break
    return "%s|%s|%s" % (gene_description, rnacentral_description, rfam_description)


def read_snrnas():
    """Collect the snRNA transcripts annotated in the Ensembl GRCh38.100 files.

    Walks the per-chromosome and non-chromosomal Ensembl ``.dat.gz`` files
    (parsed as GenBank), and for every ``misc_RNA`` feature whose note is
    ``snRNA`` extracts the transcript sequence (reverse-complemented for
    features on the minus strand) and a description string. For records
    whose sequence name is found by read_chromosome_names(), the transcript
    is additionally checked against the hg38 2bit genome and a 12-column
    BED interval is created.

    Returns:
        A tuple ``(intervals, transcripts, descriptions)``:
        intervals is a pybedtools.BedTool with one single-block 12-column
        interval per snRNA located on a known chromosome; transcripts maps
        the feature's standard name to its sequence; descriptions maps the
        standard name to ``"<gene>|<RNAcentral>|<Rfam>"``.

    Raises:
        AssertionError: if a record or feature has an unexpected layout, a
            name occurs twice, or a transcript disagrees with the genome.
    """
    gene_names = read_gene_names()
    rnacentral_annotations = read_rnacentral_annotations()
    rfam_annotations = read_rfam_annotations()
    genome_path = '/osc-fs_home/scratch/mdehoon/Data/Genomes/hg38/hg38.2bit'
    print("Parsing", genome_path)
    intervals = []
    transcripts = {}
    descriptions = {}
    directory = "/osc-fs_home/scratch/mdehoon/Data/Ensembl"
    # None stands for the file with the non-chromosomal sequences.
    numbers = [str(i) for i in list(range(1, 23)) + ['X', 'Y', 'MT']] + [None]
    # The 2bit file must stay open while sequences are fetched from `genome`;
    # it gets its own name so that it is not lost (and left unclosed) when the
    # per-chromosome files are opened below.
    with open(genome_path, 'rb') as genome_handle:
        genome = SeqIO.parse(genome_handle, "twobit")
        chromosomes = read_chromosome_names(assembly)
        for number in numbers:
            if number is None:
                filename = "Homo_sapiens.GRCh38.100.nonchromosomal.dat.gz"
            else:
                filename = "Homo_sapiens.GRCh38.100.chromosome.%s.dat.gz" % number
            path = os.path.join(directory, filename)
            print("Reading", path)
            # Notes of 'gene' features seen so far in this file, keyed by gene.
            gene_descriptions = {}
            with gzip.open(path, 'rt') as handle:
                for record in SeqIO.parse(handle, 'genbank'):
                    sequence = str(record.seq)
                    assert len(record.annotations['accessions']) == 1
                    # The accession has six ':'-separated parts; parts 3-6 are
                    # used as sequence name, 1-based start, end and strand.
                    terms = record.annotations['accessions'][0].split(":")
                    assert len(terms) == 6
                    if number is None:
                        assert terms[0] in ('chromosome', 'scaffold')
                    else:
                        assert terms[0] == 'chromosome'
                    assert terms[1] == 'GRCh38'
                    if number is not None:
                        assert terms[2] == number
                    # None if the sequence is absent from the chrom.sizes file.
                    chromosome = chromosomes.get(terms[2])
                    record_start = int(terms[3]) - 1
                    record_end = int(terms[4])
                    assert int(terms[5]) == 1
                    assert record_end - record_start == len(sequence)
                    for feature in record.features:
                        if feature.type == 'gene':
                            notes = feature.qualifiers.get('note')
                            if notes is not None:
                                assert len(notes) == 1
                                genes = feature.qualifiers['gene']
                                assert len(genes) == 1
                                gene_descriptions[genes[0]] = notes[0]
                            continue
                        if feature.type != 'misc_RNA':
                            continue
                        notes = feature.qualifiers['note']
                        assert len(notes) == 1
                        if notes[0] != 'snRNA':
                            continue
                        names = feature.qualifiers['standard_name']
                        assert len(names) == 1
                        name = names[0]
                        descriptions[name] = _describe_snrna(
                            feature, name, gene_names, gene_descriptions,
                            rnacentral_annotations, rfam_annotations)
                        # Feature coordinates are relative to the record.
                        start = int(feature.location.start)
                        end = int(feature.location.end)
                        strand = feature.location.strand
                        transcript = sequence[start:end]
                        if strand == 1:
                            strand = '+'
                        else:
                            assert strand == -1
                            strand = '-'
                            transcript = reverse_complement(transcript)
                        assert name not in transcripts
                        transcripts[name] = transcript
                        if chromosome is None:
                            # No known chromosome name: keep the sequence, but
                            # do not create an interval.
                            continue
                        # Shift to coordinates on the full sequence.
                        start += record_start
                        end += record_start
                        if strand == '+':
                            assert transcript == genome[chromosome][start:end].seq.upper()
                        else:
                            assert reverse_complement(transcript) == genome[chromosome][start:end].seq.upper()
                        length = end - start
                        lengths = "%s," % length
                        fields = [chromosome, start, end, name, "0", strand, start, end, "0", 1, lengths, "0,"]
                        interval = pybedtools.create_interval_from_list(fields)
                        intervals.append(interval)
    intervals = pybedtools.BedTool(intervals)
    return intervals, transcripts, descriptions

def read_refseq_snrna_loci():
    """Return the contents of ``snRNA.bed`` in the working directory as a BedTool."""
    bed_path = "snRNA.bed"
    print("Reading", bed_path)
    return pybedtools.BedTool(bed_path)

def read_refseq_snrna_transcripts():
    """Return all FASTA records of ``snRNA.fa`` in the working directory as a list."""
    fasta_path = "snRNA.fa"
    print("Reading", fasta_path)
    # Materialize the parser's output so that the file is read completely.
    return list(SeqIO.parse(fasta_path, 'fasta'))

# Chromosome sizes, looked up per chromosome when the PSL file is written.
sizes = read_chromosome_sizes(assembly)

intervals, sequences, descriptions = read_snrnas()
print("Total number of Ensembl snRNAs: %d" % len(sequences))
print("Total number of mapped Ensembl snRNAs: %d" % len(intervals))


# Names of the Ensembl snRNAs that overlap a RefSeq snRNA locus on the same
# strand (s=True). A set is used because membership is tested once per snRNA.
refseq_intervals = read_refseq_snrna_loci()
overlapping_intervals = intervals.intersect(refseq_intervals, s=True)
overlapping = {interval.name for interval in overlapping_intervals}

# Keep every RefSeq locus, and add only the Ensembl loci without such overlap.
merged_intervals = list(refseq_intervals)
merged_intervals.extend(interval for interval in intervals
                        if interval.name not in overlapping)

merged_intervals = pybedtools.BedTool(merged_intervals)
merged_intervals = merged_intervals.sort()

# Write one tab-separated, 21-column PSL-style line per merged interval.
# Each interval is written as a single block in which the whole query matches.
filename = "snRNA.psl"
print("Writing %s" % filename)
# The context manager closes the output file even if an interval is rejected.
with open(filename, 'w') as handle:
    for interval in merged_intervals:
        qSize = interval.end - interval.start
        blockCount = int(interval.fields[9])
        # Only single-block intervals can be written with the layout below.
        assert blockCount == 1
        matches = qSize
        misMatches = 0
        repMatches = 0
        nCount = 0
        qNumInsert = 0
        qBaseInsert = 0
        tNumInsert = 0
        tBaseInsert = 0
        strand = interval.strand
        qName = interval.name
        qStart = 0
        qEnd = qSize
        tName = interval.chrom
        tSize = sizes[tName]
        tStart = interval.start
        tEnd = interval.end
        blockSizes = "%d," % qSize
        qStarts = "0,"
        tStarts = "%d," % tStart
        fields = [matches,
                  misMatches,
                  repMatches,
                  nCount,
                  qNumInsert,
                  qBaseInsert,
                  tNumInsert,
                  tBaseInsert,
                  strand,
                  qName,
                  qSize,
                  qStart,
                  qEnd,
                  tName,
                  tSize,
                  tStart,
                  tEnd,
                  blockCount,
                  blockSizes,
                  qStarts,
                  tStarts]
        line = "\t".join([str(field) for field in fields]) + "\n"
        handle.write(line)

# read_refseq_snrna_transcripts() loads all of snRNA.fa into a list, so the
# same file can safely be reopened for writing below.
records = read_refseq_snrna_transcripts()
filename = "snRNA.fa"
print("Writing %s" % filename)
# The context manager closes the output file even if writing fails part-way.
with open(filename, 'w') as handle:
    # First the records that were already in the file, unchanged ...
    for record in records:
        handle.write(format(record, 'fasta'))

    # ... then the Ensembl snRNAs that do not overlap any of the RefSeq loci.
    for name, sequence in sequences.items():
        if name in overlapping:
            continue
        description = descriptions[name]
        handle.write('>%s %s\n' % (name, description))
        handle.write('%s\n' % sequence)
